--- train_layout.py	2023-04-06 03:36:13.157714906 +0000
+++ external_reference/train_layout.py	2023-04-06 03:32:01.086219550 +0000
@@ -17,8 +17,8 @@
 import torch
 
 import dnnlib
-from training import training_loop
-from metrics import metric_main
+from external.stylegan.training import training_loop
+from external.stylegan.metrics import metric_main
 from torch_utils import training_stats
 from torch_utils import custom_ops
 
@@ -101,7 +101,10 @@
 
 def init_dataset_kwargs(data):
     try:
-        dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
+        # add a temporary resolution for kwargs, will be overridden later
+        dataset_kwargs = dnnlib.EasyDict(
+            class_name='external.stylegan.training.dataset.ImageFolderDataset',
+            path=data, resolution=64, use_labels=True, max_size=None, xflip=False)
         dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
         dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
         dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
@@ -161,6 +164,62 @@
 @click.option('--workers',      help='DataLoader worker processes', metavar='INT',              type=click.IntRange(min=1), default=3, show_default=True)
 @click.option('-n','--dry-run', help='Print training options and exit',                         is_flag=True)
 
+## ADDED: additional training options
+# which part of the model is being trained
+@click.option('--training-mode', help='which part of the model is being trained', type=str, required=True)
+
+# dataset options
+@click.option('--pose', help='pose path', metavar='[ZIP|DIR]', type=str)
+@click.option('--img-resolution', help='image resolution', type=click.IntRange(min=1), default=256, show_default=True)
+@click.option('--depth-scale', help='depth scale factor', metavar='FLOAT', type=click.FloatRange(min=0), default=16, show_default=True)
+@click.option('--depth-clip', help='depth clip factor', metavar='FLOAT', type=click.FloatRange(min=0), default=20, show_default=True)
+@click.option('--use-disp', help='train using disparity', metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--fov-mean', help='mean fov', metavar='FLOAT', type=click.FloatRange(min=0,max=180), default=60, show_default=True)
+@click.option('--fov-std', help='std fov', metavar='FLOAT', type=click.FloatRange(min=0,max=180), default=0, show_default=True)
+
+# layout generator options
+@click.option('--z-dim', help='mapping layers input dimension', type=click.IntRange(min=1), default=128, show_default=True)
+@click.option('--num-layers', help='number of generator layers', type=click.IntRange(min=1), default=6, show_default=True)
+@click.option('--num-decoder-ch', help='number of layout features', type=click.IntRange(min=1), default=32, show_default=True)
+@click.option('--voxel-res', help='voxel res of decoder', type=click.IntRange(min=1), default=256, show_default=True)
+@click.option('--voxel-size', help='voxel size', metavar='FLOAT', type=click.FloatRange(min=0), default=0.15, show_default=True)
+
+# layout decoder options
+@click.option('--feature-nerf', help='does not apply activation to rgb', metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--nerf-out-res', help='output resolution of nerf', type=click.IntRange(min=1), default=32, show_default=True)
+@click.option('--nerf-out-ch', help='output channels of nerf', type=click.IntRange(min=1), default=128, show_default=True)
+@click.option('--nerf-n-layers', help='layers in nerf mlp', type=click.IntRange(min=1), default=8, show_default=True)
+@click.option('--nerf-far', help='nerf far bound', metavar='FLOAT', type=click.FloatRange(min=1), default=16., show_default=True)
+@click.option('--nerf-samples-per-ray', help='nerf samples per ray', type=click.IntRange(min=1), default=128, show_default=True)
+
+# layout: loss arguments
+@click.option('--concat-depth', help='depth to discriminator', metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--concat-acc', help='acc to discriminator', metavar='BOOL', type=bool, default=False, show_default=True)
+@click.option('--recon-weight', help='recon weight', metavar='FLOAT', type=click.FloatRange(min=0), default=1000., show_default=True)
+@click.option('--aug-policy', help='aug policy', type=str, default='translation,color,cutout', required=True)
+@click.option('--d-cbase', help='Capacity multiplier', metavar='INT', type=click.IntRange(min=1), default=32768, show_default=True)
+@click.option('--d-cmax', help='Max. feature maps', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
+@click.option('--ray-lambda-finite-difference', help='recon weight', metavar='FLOAT', type=click.FloatRange(min=0), default=0., show_default=True)
+@click.option('--ray-lambda-ramp-end', help='ramp function end (kimg)', type=click.IntRange(min=1), default=1000, show_default=True)
+@click.option('--use-wrapped-discriminator', help='wrap discriminator with second acc discriminator', type=bool, default=False, show_default=True)
+
+# upsampler model: architecture
+@click.option('--input-resolution', help='upsampler input resolution', type=click.IntRange(min=1), default=128, show_default=True)
+@click.option('--concat-depth-and-acc', help='upsampler also generates depth, acc', metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--use-3d-noise', help='sgan2 3D noise in upsampler', metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--layout-model', help='layout model path', type=str)
+
+# upsampler model: losses
+@click.option('--lambda-rec', help='recon weight', metavar='FLOAT', type=click.FloatRange(min=0), default=0., show_default=True)
+@click.option('--lambda-up', help='upsample recon weight', metavar='FLOAT', type=click.FloatRange(min=0), default=0., show_default=True)
+@click.option('--lambda-gray-pixel', help='penalty on gray pixels where acc=1', metavar='FLOAT', type=click.FloatRange(min=0), default=0., show_default=True)
+@click.option('--lambda-gray-pixel-falloff', help='penalty on gray pixels; exponential falloff rate', metavar='FLOAT', type=click.FloatRange(min=0), default=20, show_default=True)
+@click.option('--D-ignore-depth-acc', help='ignores depth, acc inputs to discriminator', metavar='BOOL', type=bool, default=True, show_default=True)
+
+# additional arguments for sky model
+@click.option('--mask-prob', help='probability to use gt rgb in sky generator output', type=float)
+
+
 def main(**kwargs):
     """Train a GAN using the techniques described in the paper
     "Alias-Free Generative Adversarial Networks".
@@ -187,19 +246,33 @@
     # Initialize config.
     opts = dnnlib.EasyDict(kwargs) # Command line arguments.
     c = dnnlib.EasyDict() # Main config dict.
-    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict())
-    c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
+    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=opts.z_dim, w_dim=opts.z_dim, mapping_kwargs=dnnlib.EasyDict())
+    c.D_kwargs = dnnlib.EasyDict(class_name='external.stylegan.training.networks_stylegan2_terrain.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
     c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
     c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
-    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss')
+    c.loss_kwargs = dnnlib.EasyDict(class_name='external.stylegan.training.loss.StyleGAN2Loss')
     c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)
 
+    # check the training mode
+    assert(opts.training_mode in ['layout', 'upsampler', 'sky'])
+
     # Training set.
     c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data)
     if opts.cond and not c.training_set_kwargs.use_labels:
         raise click.ClickException('--cond=True requires labels specified in dataset.json')
     c.training_set_kwargs.use_labels = opts.cond
     c.training_set_kwargs.xflip = opts.mirror
+    c.training_set_kwargs.resolution = opts.img_resolution
+    c.training_set_kwargs.depth_scale = opts.depth_scale
+    c.training_set_kwargs.depth_clip = opts.depth_clip
+    c.training_set_kwargs.use_disp = opts.use_disp
+    c.training_set_kwargs.fov_mean = opts.fov_mean
+    c.training_set_kwargs.fov_std = opts.fov_std
+    c.training_set_kwargs.pose_path = opts.pose
+    if opts.training_mode == 'layout':
+        # additional sanity checks
+        assert(opts.depth_scale == opts.nerf_far) # depth scale should match nerf far bound
+        assert(opts.use_disp == True) # train it on disparity (inverse depth)
 
     # Hyperparameters & settings.
     c.num_gpus = opts.gpus
@@ -232,15 +305,22 @@
 
     # Base configuration.
     c.ema_kimg = c.batch_size * 10 / 32
-    if opts.cfg == 'stylegan2':
-        c.G_kwargs.class_name = 'training.networks_stylegan2.Generator'
-        c.loss_kwargs.style_mixing_prob = 0.9 # Enable style mixing regularization.
-        c.loss_kwargs.pl_weight = 2 # Enable path length regularization.
-        c.G_reg_interval = 4 # Enable lazy regularization for G.
+    if opts.cfg == 'stylegan2' and opts.training_mode == 'layout':
+        c.G_kwargs.class_name = 'external.stylegan.training.networks_stylegan2_terrain.Generator'
+        c.loss_kwargs.style_mixing_prob = 0 # Style mixing probability set to zero (upstream stylegan2 value: 0.9).
+        c.loss_kwargs.pl_weight = 0 # Path length regularization weight set to zero (upstream stylegan2 value: 2).
+        c.G_reg_interval = None # No lazy regularization interval for G (upstream stylegan2 value: 4).
         c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
         c.loss_kwargs.pl_no_weight_grad = True # Speed up path length regularization by skipping gradient computation wrt. conv2d weights.
-    else:
-        c.G_kwargs.class_name = 'training.networks_stylegan3.Generator'
+    elif opts.cfg == 'stylegan2' and opts.training_mode == 'upsampler':
+        c.G_kwargs.class_name = 'external.stylegan.training.networks_stylegan2_terrain.Generator'
+        c.loss_kwargs.style_mixing_prob = 0.0 # Style mixing probability set to zero (regularization disabled).
+        c.loss_kwargs.pl_weight = 2 # Enable path length regularization. 
+        c.G_reg_interval = 4 # Enable lazy regularization for G. 
+        c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
+        c.loss_kwargs.pl_no_weight_grad = True # Speed up path length regularization by skipping gradient computation wrt. conv2d weights.
+    elif opts.training_mode == 'sky':
+        c.G_kwargs.class_name = 'external.stylegan.training.networks_stylegan3_sky.Generator'
         c.G_kwargs.magnitude_ema_beta = 0.5 ** (c.batch_size / (20 * 1e3))
         if opts.cfg == 'stylegan3-r':
             c.G_kwargs.conv_kernel = 1 # Use 1x1 convolutions.
@@ -250,9 +330,82 @@
             c.loss_kwargs.blur_init_sigma = 10 # Blur the images seen by the discriminator.
             c.loss_kwargs.blur_fade_kimg = c.batch_size * 200 / 32 # Fade out the blur during the first N kimg.
 
+    # additional model arguments
+    c.training_mode = opts.training_mode
+    if opts.training_mode == 'layout':
+        c.G_kwargs.img_resolution = opts.voxel_res
+        c.G_kwargs.img_channels = opts.num_decoder_ch
+        nerf_z_dim = opts.num_decoder_ch
+        nerf_mlp_kwargs = dnnlib.EasyDict(
+            class_name='external.gsn.models.generator.NerfStyleGenerator',
+            n_layers=opts.nerf_n_layers,
+            channels=128,
+            out_channel=opts.nerf_out_ch,
+            z_dim=nerf_z_dim,
+        )
+        generator_res = opts.voxel_res
+        c.decoder_kwargs = dnnlib.EasyDict(
+            class_name='external.gsn.models.generator.SceneGenerator',
+            nerf_mlp_config=nerf_mlp_kwargs,
+            img_res=opts.img_resolution,
+            feature_nerf=opts.feature_nerf,
+            global_feat_res=generator_res,
+            coordinate_scale=generator_res * opts.voxel_size,
+            alpha_activation='softplus',
+            local_coordinates=True,
+            hierarchical_sampling=False,
+            density_bias=0,
+            zfar_bias=False,
+            use_disp=opts.use_disp,
+            nerf_out_res=opts.nerf_out_res,
+            samples_per_ray=opts.nerf_samples_per_ray,
+            near=1,
+            far=opts.nerf_far,
+            alpha_noise_std=0,
+        )
+        c.torgb_kwargs = dnnlib.EasyDict(
+            class_name='models.misc.networks.ToRGBTexture',
+            in_channel=opts.nerf_out_ch,
+        )
+        c.wrapper_kwargs = dnnlib.EasyDict(
+            voxel_res=generator_res, # opts.voxel_res,
+            voxel_size=opts.voxel_size,
+            img_res=opts.img_resolution,
+            fov_mean=opts.fov_mean,
+            fov_std=opts.fov_std,
+        )
+        c.loss_kwargs.loss_layout_kwargs = dnnlib.EasyDict(
+            concat_depth = opts.concat_depth,
+            concat_acc = opts.concat_acc,
+            recon_weight = opts.recon_weight,
+            aug_policy = opts.aug_policy,
+            lambda_finite_difference=opts.ray_lambda_finite_difference,
+            lambda_ramp_end = opts.ray_lambda_ramp_end * 1000, # option is given in kimg; convert to number of images
+            use_wrapped_discriminator = opts.use_wrapped_discriminator
+        )
+    elif opts.training_mode == 'upsampler':
+        c.G_kwargs.use_noise = True # helps improve texture
+        c.G_kwargs.default_noise_mode = '3dnoise' if opts.use_3d_noise else 'random'
+        c.G_kwargs.input_resolution = opts.input_resolution
+        c.G_kwargs.num_additional_feature_channels = 2 if opts.concat_depth_and_acc else 0
+        c.loss_kwargs.loss_upsampler_kwargs = dnnlib.EasyDict(
+            d_ignore_depth_acc = opts.d_ignore_depth_acc,
+            lambda_rec = opts.lambda_rec,
+            lambda_up = opts.lambda_up,
+            lambda_gray_pixel = opts.lambda_gray_pixel,
+            lambda_gray_pixel_falloff = opts.lambda_gray_pixel_falloff,
+        )
+        c.wrapper_kwargs = dnnlib.EasyDict(
+            layout_model_path = opts.layout_model,
+        )
+    elif opts.training_mode == 'sky':
+        c.loss_kwargs.loss_sky_kwargs = dnnlib.EasyDict(
+            mask_prob = opts.mask_prob
+        )
+
     # Augmentation.
     if opts.aug != 'noaug':
-        c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
+        c.augment_kwargs = dnnlib.EasyDict(class_name='external.stylegan.training.augment.AugmentPipe', xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
         if opts.aug == 'ada':
             c.ada_target = opts.target
         if opts.aug == 'fixed':
